import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


def fixed_padding(inputs, kernel_size, dilation):
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
    return padded_inputs


class SeparableConv2d(nn.Module):

    def __init__(self,
                 inplanes,
                 planes,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 bias=False):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(
            inplanes,
            inplanes,
            kernel_size,
            stride,
            0,
            dilation,
            groups=inplanes,
            bias=bias)
        self.bn = nn.BatchNorm2d(inplanes)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = fixed_padding(
            x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
        x = self.conv1(x)
        x = self.bn(x)
        x = self.pointwise(x)
        return x


# encoder block
class Block(nn.Module):

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 start_with_relu=True):
        super(Block, self).__init__()

        if planes != inplanes or stride != 1:
            self.skip = nn.Conv2d(
                inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm2d(planes)
        else:
            self.skip = None
        first_conv = []
        rep = []
        # Deep SeparableConv1
        if start_with_relu:
            first_conv.append(nn.ReLU())
            first_conv.append(
                SeparableConv2d(inplanes, planes // 4, 3, 1, dilation))
            first_conv.append(nn.BatchNorm2d(planes // 4))
            first_conv.append(nn.ReLU())
        if not start_with_relu:
            first_conv.append(
                SeparableConv2d(inplanes, planes // 4, 3, 1, dilation))
            first_conv.append(nn.BatchNorm2d(planes // 4))
            first_conv.append(nn.ReLU())

        rep.append(SeparableConv2d(planes // 4, planes // 4, 3, 1, dilation))
        rep.append(nn.BatchNorm2d(planes // 4))

        if stride != 1:
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(planes // 4, planes, 3, 2))
            rep.append(nn.BatchNorm2d(planes))

        if stride == 1:
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(planes // 4, planes, 3, 1))
            rep.append(nn.BatchNorm2d(planes))

        self.first_conv = nn.Sequential(*first_conv)
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.first_conv(inp)
        x = self.rep(x)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x = x + skip

        return x


class enc(nn.Module):
    """
    encoders:
    stage:stage=X ,where X means encX,example: stage=2 that means you defined the encoder enc2
    """

    def __init__(self, in_channels, out_channels, stage):
        super(enc, self).__init__()
        if (stage == 2 or stage == 4):
            rep_nums = 4
        elif (stage == 3):
            rep_nums = 6
        rep = []
        rep.append(
            Block(in_channels, out_channels, stride=2, start_with_relu=False))
        for i in range(rep_nums - 1):
            rep.append(
                Block(
                    out_channels, out_channels, stride=1, start_with_relu=True))

        self.reps = nn.Sequential(*rep)

    def forward(self, lp):
        x = self.reps(lp)
        return x


class fcattention(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(fcattention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(in_channels, 1000, bias=False),
            # nn.ReLU(inplace=True),
        )

        self.conv = nn.Sequential(
            nn.Conv2d(1000, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        # print(y.size())
        y = self.fc(y).view(b, 1000, 1, 1)
        # print(y.size())
        y = self.conv(y)
        return x * y.expand_as(x)


class xceptionA(nn.Module):
    """
    """

    def __init__(self, num_classes):
        super(xceptionA, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=8,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False), nn.BatchNorm2d(num_features=8), nn.ReLU())
        self.enc2 = enc(in_channels=8, out_channels=48, stage=2)
        self.enc3 = enc(in_channels=48, out_channels=96, stage=3)
        self.enc4 = enc(in_channels=96, out_channels=192, stage=4)
        self.fca = fcattention(192, 192)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(192, 1000)

    def forward(self, x):
        b, c, _, _ = x.size()
        x = self.conv1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)
        x = self.fca(x)
        x = self.avg_pool(x).view(b, -1)
        # print(x.size())
        x = self.fc(x)
        return x


if __name__ == '__main__':
    net = xceptionA(num_classes=19)

    input = torch.randn(4, 3, 1024, 1024)
    print(net)
    torch.save(net.state_dict(), "backbone.pth")
    outputs = net(input)
    print(outputs.size())
